[Attention] Support MTP with DCP #24997
Conversation
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Code Review
This pull request aims to enable Multi-Token Prediction (MTP) with decode context parallelism (DCP) by supporting query_len > 1 in the FlashAttention MLA backend. The changes remove the previous restrictions and add metadata for a custom causal mask.
I've found a critical issue where the new logic to compute query_base_positions in MLACommonMetadataBuilder is not being used by FlashAttnMLAMetadataBuilder because it overrides the _build_decode method. This will prevent the feature from working as intended. Please see my detailed comment.
# Compute DCP query base positions if using DCP
query_base_positions = None
if self.dcp_world_size > 1:
    query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    query_base_positions = (seq_lens_cpu - query_lens).to(
        seq_lens_device.device)

return MLACommonDecodeMetadata(
    block_table=block_table_tensor,
    seq_lens=seq_lens_device,
    query_base_positions=query_base_positions,
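In other words, query_base_positions records, for each decode request, the absolute position of its first query token (seq_len − query_len); with query_len > 1 under DCP, this offset is presumably what lets the backend apply the custom causal mask mentioned above against the sharded KV cache.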
This change correctly computes query_base_positions for DCP. However, FlashAttnMLAMetadataBuilder in vllm/v1/attention/backends/mla/flashattn_mla.py overrides _build_decode and does not call this base implementation. As a result, query_base_positions will be None for the FlashAttention MLA backend, and the MTP with context parallelism feature will not work correctly.
To fix this, either move this logic into FlashAttnMLAMetadataBuilder._build_decode or refactor it so that FlashAttnMLAMetadataBuilder can reuse it — for example, by computing query_base_positions in FlashAttnMLAMetadataBuilder._build_decode and passing it to the FlashAttnMLADecodeMetadata constructor, as sketched below.
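A minimal sketch of that suggestion, assuming a simplified _build_decode signature; only the query_base_positions computation and the field names come from the snippet quoted above, the rest (argument list, other metadata fields) is hypothetical:

```python
# Hypothetical sketch inside FlashAttnMLAMetadataBuilder; the exact
# _build_decode signature and the other FlashAttnMLADecodeMetadata
# fields are assumptions, not the actual vLLM code.
def _build_decode(self, block_table_tensor, seq_lens_cpu, seq_lens_device,
                  query_start_loc_cpu):
    query_base_positions = None
    if self.dcp_world_size > 1:
        # Same computation as the base class: the first query token of each
        # request sits at absolute position seq_len - query_len.
        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        query_base_positions = (seq_lens_cpu - query_lens).to(
            seq_lens_device.device)

    return FlashAttnMLADecodeMetadata(
        block_table=block_table_tensor,
        seq_lens=seq_lens_device,
        # ... plus the FlashAttn-specific fields the existing override builds ...
        query_base_positions=query_base_positions,
    )
```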
This pull request has merge conflicts that must be resolved before it can be merged.

Superseded by #25049.
Purpose
#24453 added DCP support but did not support query_len > 1. This PR, which depends on a corresponding FlashAttention PR (vllm-project/flash-attention#92), implements a custom causal mask to take advantage of the FlashAttention MLA backend's support for query_len > 1, thereby enabling MTP.
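The rough idea behind such a mask can be illustrated as follows. This is a standalone sketch, not the actual kernel change; it assumes the KV cache is interleaved across DCP ranks token by token, and the function name and arguments are hypothetical:

```python
import torch

def dcp_causal_mask(query_base: int, query_len: int,
                    num_local_kv: int, dcp_rank: int,
                    dcp_world_size: int) -> torch.Tensor:
    # Absolute positions of the query tokens; query_base = seq_len - query_len.
    q_pos = query_base + torch.arange(query_len)                      # [Q]
    # Absolute positions of the KV tokens stored on this DCP rank,
    # assuming token i of the sequence lives on rank i % dcp_world_size.
    kv_pos = torch.arange(num_local_kv) * dcp_world_size + dcp_rank   # [K]
    # Standard causal rule, applied in absolute-position space.
    return kv_pos[None, :] <= q_pos[:, None]                          # [Q, K] bool
```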
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.